{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import numpy as np\n",
    "import string\n",
    "data_train_path=\"data\\\\snli\\\\snli_1.0_train.txt\"\n",
    "embedding_file=\"data\\\\embedding\\\\glove.6B.50d.txt\"\n",
    "\n",
    "worddict_path=\"data\\\\worddict.txt\"\n",
    "data_train_str=\"data\\\\train_data_str.pkl\"\n",
    "data_train_id=\"data\\\\train_data_id.pkl\"\n",
    "embedding_matrix=\"data\\\\embedding_matrix.pkl\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_data(data_path):\n",
    "    premise=[]\n",
    "    hypothesis=[]\n",
    "    labels=[] \n",
    "    labels_map={\"entailment\":0,\"neutral\":1,\"contradiction\":2}\n",
    "    punct_table = str.maketrans({key: \" \" for key in string.punctuation})\n",
    "    with open(data_path,'r',encoding='utf-8') as lines:\n",
    "        next(lines)\n",
    "        for line in lines:\n",
    "            line=line.strip().split('\\t')\n",
    "            if line[0] not in labels_map:   #忽略没有label的例子\n",
    "                continue\n",
    "            premise.append(line[5].translate(punct_table).lower())\n",
    "            hypothesis.append(line[6].translate(punct_table).lower())\n",
    "            labels.append(line[0])\n",
    "    return {\"premise\":premise,\n",
    "            \"hypothesis\":hypothesis,\n",
    "            \"labels\":labels}    \n",
    "def build_worddict(data):\n",
    "    words=[]\n",
    "    words.extend([\"_PAD_\",\"_OOV_\",\"_BOS_\",\"_EOS_\"])\n",
    "    for sentence in data[\"premise\"]:\n",
    "        words.extend(sentence.strip().split(\" \"))\n",
    "    for sentence in data[\"hypothesis\"]:\n",
    "        words.extend(sentence.strip().split(\" \")) \n",
    "    word_id={}\n",
    "    id_word={}\n",
    "    i=0\n",
    "    for index,word in enumerate(words):\n",
    "        if word not in word_id:\n",
    "            word_id[word]=i\n",
    "            id_word[i]=word\n",
    "            i+=1\n",
    "    #保存词典\n",
    "    with open(worddict_path, \"w\",encoding='utf-8') as f:\n",
    "        for word in word_id:\n",
    "            f.write(\"%s\\t%d\\n\"%(word, word_id[word]))\n",
    "    return word_id,id_word\n",
    "\n",
    "def sentence2idList(sentence,word_id):\n",
    "    ids=[]\n",
    "    ids.append(word_id[\"_BOS_\"])\n",
    "    sentence=sentence.strip().split(\" \")\n",
    "    for word in sentence:\n",
    "        if word not in word_id:\n",
    "            ids.append(word_id[\"_OOV_\"])\n",
    "        else:\n",
    "            ids.append(word_id[word])\n",
    "    ids.append(word_id[\"_EOS_\"])\n",
    "    return ids\n",
    "\n",
    "def data2id(data,word_id):\n",
    "    premise_id=[]\n",
    "    hypothesis_id=[]\n",
    "    labels_id=[] \n",
    "    labels_map={\"entailment\":0,\"neutral\":1,\"contradiction\":2}\n",
    "    for i,label in enumerate(data[\"labels\"]):\n",
    "        if label not in labels_map:   #忽略没有label的例子\n",
    "            continue\n",
    "        premise_id.append(sentence2idList(data[\"premise\"][i],word_id))\n",
    "        hypothesis_id.append(sentence2idList(data[\"hypothesis\"][i],word_id))\n",
    "        labels_id.append(labels_map[label])\n",
    "            \n",
    "    return {\"premise_id\":premise_id,\n",
    "            \"hypothesis_id\":hypothesis_id,\n",
    "            \"labels_id\":labels_id}    \n",
    "\n",
    "def build_embeddings(embedding_file,word_id):\n",
    "    #读取文件存入集合中\n",
    "    embeddings_map={}\n",
    "    with open(embedding_file,'r',encoding='utf-8') as f:\n",
    "        for line in f:\n",
    "            line=line.strip().split()\n",
    "            word=line[0]\n",
    "            if word in word_id:\n",
    "                embeddings_map[word]=line[1:]   \n",
    "    #放入矩阵中\n",
    "    words_num = len(word_id)\n",
    "    embedding_dim=len(embeddings_map['a'])\n",
    "    embedding_matrix=np.zeros((words_num,embedding_dim))\n",
    "    #print(words_num,embedding_dim)\n",
    "    missed_cnt=0\n",
    "    for i,word in enumerate(word_id):\n",
    "        if word in embeddings_map:\n",
    "            embedding_matrix[i]=embeddings_map[word]\n",
    "        else:\n",
    "            if word==\"_PAD_\":\n",
    "                continue\n",
    "            missed_cnt+=1\n",
    "            embedding_matrix[i]=np.random.normal(size=embedding_dim)\n",
    "    print(\"missed word count: %d\"%(missed_cnt)) \n",
    "    return embeddings_map,embedding_matrix\n",
    "          \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "#读取数据\n",
    "# Read the training data\n",
    "data_str=read_data(data_train_path)\n",
    "# Build the vocabulary (word -> id and id -> word dictionaries)\n",
    "word_id,id_word=build_worddict(data_str)   \n",
    "# Clean the data and convert it to ids\n",
    "data_id=data2id(data_str,word_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "#保存 data_train_str和data_train_id\n",
    "# Pickle the string-form data to data_train_str and the id-form data to data_train_id\n",
    "with open(data_train_str,\"wb\") as f:\n",
    "    pickle.dump(data_str,f)\n",
    "with open(data_train_id,\"wb\") as f:\n",
    "    pickle.dump(data_id,f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "missed word count: 5994\n",
      "33268\n",
      "27273\n"
     ]
    }
   ],
   "source": [
    "# Bind the matrix to its own name: `embedding_matrix` holds the output path\n",
    "# defined in the configuration cell and must stay a path for the dump below.\n",
    "embeddings_map,embedding_weights=build_embeddings(embedding_file,word_id)\n",
    "print(len(embedding_weights))\n",
    "print(len(embeddings_map))\n",
    "# Save the embedding matrix itself (data_id is already pickled by an earlier cell).\n",
    "with open(embedding_matrix,\"wb\") as f:\n",
    "    pickle.dump(embedding_weights,f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python36",
   "language": "python",
   "name": "python36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
